{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# RBAC with RAG\n",
    "This notebook accompanies the blog post [RBAC with RAG - Best of Friends](https://www.elastic.co/search-labs/blog/rbac-and-rag-best-friends).\n",
    "\n",
    "It demonstrates how users assigned to different roles can query the same index pattern yet retrieve only the documents they are authorized to read."
   ],
   "metadata": {
    "id": "LHUQsHIITyJ9"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Indices:\n",
    "- `rbac_rag_demo-data_public` contains data that is not restricted\n",
    "- `rbac_rag_demo-data_sensitive` contains data that is restricted to managers only\n",
    "\n",
    "Roles:\n",
    "- `engineer_role` has read access to the `rbac_rag_demo-data_public` index only\n",
    "- `manager_role` has read access to both the `rbac_rag_demo-data_public` and `rbac_rag_demo-data_sensitive` indices\n",
    "\n",
    "Users:\n",
    "- `engineer` is assigned `engineer_role`\n",
    "- `manager` is assigned `manager_role`"
   ],
   "metadata": {
    "id": "YzkOllkZydjr"
   }
  },
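  {
   "cell_type": "markdown",
   "source": [
    "The access model described above can be illustrated with a small pure-Python sketch. It is a toy model only, not how Elasticsearch implements security: the `INDICES` and `ROLE_READABLE` dictionaries and the `visible_titles` function are hypothetical stand-ins, and the sketch assumes that a wildcard index pattern resolves only to the indices a role is allowed to read.\n",
    "\n",
    "```python\n",
    "from fnmatch import fnmatchcase\n",
    "\n",
    "# Hypothetical in-memory stand-ins for the two demo indices (one title each).\n",
    "INDICES = {\n",
    "    'rbac_rag_demo-data_public': ['Annual leave policies updated.'],\n",
    "    'rbac_rag_demo-data_sensitive': ['Executive compensation details Q2 2024.'],\n",
    "}\n",
    "\n",
    "# Hypothetical map from each role to the indices it may read.\n",
    "ROLE_READABLE = {\n",
    "    'engineer_role': {'rbac_rag_demo-data_public'},\n",
    "    'manager_role': {'rbac_rag_demo-data_public', 'rbac_rag_demo-data_sensitive'},\n",
    "}\n",
    "\n",
    "\n",
    "def visible_titles(role, index_pattern):\n",
    "    # Keep only indices that match the pattern AND that the role may read.\n",
    "    allowed = ROLE_READABLE[role]\n",
    "    titles = []\n",
    "    for name, docs in sorted(INDICES.items()):\n",
    "        if fnmatchcase(name, index_pattern) and name in allowed:\n",
    "            titles.extend(docs)\n",
    "    return titles\n",
    "\n",
    "\n",
    "print(visible_titles('engineer_role', 'rbac_rag_demo-data*'))\n",
    "print(visible_titles('manager_role', 'rbac_rag_demo-data*'))\n",
    "```\n",
    "\n",
    "Both calls use the same pattern and differ only in the role, which mirrors what the rest of this notebook demonstrates against a real cluster."
   ],
   "metadata": {}
  },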
  {
   "cell_type": "markdown",
   "source": [
    "# Environment setup"
   ],
   "metadata": {
    "id": "-keuO8s9VI0A"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Install and import required Python libraries"
   ],
   "metadata": {
    "id": "QmVaDHHdyV9B"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "!pip install elasticsearch"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "MV8m_Yb2yUlR",
    "outputId": "2a12455b-e237-4c2e-fa9a-b926f1f23fa1"
   },
   "execution_count": 1,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Requirement already satisfied: elasticsearch in /usr/local/lib/python3.10/dist-packages (8.13.1)\n",
      "Requirement already satisfied: elastic-transport<9,>=8.13 in /usr/local/lib/python3.10/dist-packages (from elasticsearch) (8.13.0)\n",
      "Requirement already satisfied: urllib3<3,>=1.26.2 in /usr/local/lib/python3.10/dist-packages (from elastic-transport<9,>=8.13->elasticsearch) (2.0.7)\n",
      "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from elastic-transport<9,>=8.13->elasticsearch) (2024.2.2)\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "from elasticsearch import Elasticsearch\n",
    "from IPython.display import HTML, display\n",
    "from pprint import pprint\n",
    "\n",
    "import getpass"
   ],
   "metadata": {
    "id": "Z2Q1j45gT1M5"
   },
   "execution_count": 2,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Cloud ID and API Key\n",
    "\n",
    "Run the cell below and enter your Elastic Cloud ID and Elasticsearch API key.\n",
    "\n",
    "Use an existing API key that can create, delete, and query indices. The setup cells below also create roles and users with this key, so it needs privileges to manage security as well."
   ],
   "metadata": {
    "id": "cvqOeUfwl-xK"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "cloud_id = getpass.getpass(\"Enter your Elastic Cloud ID: \")\n",
    "api_key = getpass.getpass(\n",
    "    \"Enter your API key (with access to create, delete, and query indices): \"\n",
    ")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "IW6OP3xSl5yO",
    "outputId": "d4f73bbe-5ee9-4ee4-de0d-4d7a34d698c4"
   },
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Enter your Elastic Cloud ID: ··········\n",
      "Enter your API key (with access to create, delete, and query indices): ··········\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Elasticsearch Setup"
   ],
   "metadata": {
    "id": "vcwg15wbVB9_"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Create an Elasticsearch connection for index and user setup"
   ],
   "metadata": {
    "id": "f3MkYRmKUWEM"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "es = Elasticsearch(cloud_id=cloud_id, api_key=api_key)"
   ],
   "metadata": {
    "id": "dUTnXh7GUdQp"
   },
   "execution_count": 4,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Delete the demo indices if they already exist"
   ],
   "metadata": {
    "id": "6Xxkh3oqYFFW"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Delete indices\n",
    "def delete_indices():\n",
    "    try:\n",
    "        es.indices.delete(index=\"rbac_rag_demo-data_public\")\n",
    "        print(\"Deleted index: rbac_rag_demo-data_public\")\n",
    "    except Exception as e:\n",
    "        print(f\"Error deleting index rbac_rag_demo-data_public: {str(e)}\")\n",
    "\n",
    "    try:\n",
    "        es.indices.delete(index=\"rbac_rag_demo-data_sensitive\")\n",
    "        print(\"Deleted index: rbac_rag_demo-data_sensitive\")\n",
    "    except Exception as e:\n",
    "        print(f\"Error deleting index rbac_rag_demo-data_sensitive: {str(e)}\")\n",
    "\n",
    "\n",
    "delete_indices()"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "K-tfaBftYME-",
    "outputId": "d57dc00d-8e36-4de5-8c8e-40b216e1a002"
   },
   "execution_count": 5,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Deleted index: rbac_rag_demo-data_public\n",
      "Deleted index: rbac_rag_demo-data_sensitive\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Create and load data into indices\n"
   ],
   "metadata": {
    "id": "TrgFvCU3TuHN"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Create indices\n",
    "def create_indices():\n",
    "    # Create data_public index\n",
    "    # ignore_status=400 ignores the HTTP 400 returned when the index already exists\n",
    "    es.options(ignore_status=400).indices.create(\n",
    "        index=\"rbac_rag_demo-data_public\",\n",
    "        body={\n",
    "            \"settings\": {\"number_of_shards\": 1},\n",
    "            \"mappings\": {\n",
    "                \"properties\": {\n",
    "                    \"title\": {\"type\": \"text\"},\n",
    "                    \"confidentiality_level\": {\"type\": \"keyword\"},\n",
    "                }\n",
    "            },\n",
    "        },\n",
    "    )\n",
    "\n",
    "    # Create data_sensitive index with the same fields as the sample documents\n",
    "    es.options(ignore_status=400).indices.create(\n",
    "        index=\"rbac_rag_demo-data_sensitive\",\n",
    "        body={\n",
    "            \"settings\": {\"number_of_shards\": 1},\n",
    "            \"mappings\": {\n",
    "                \"properties\": {\n",
    "                    \"title\": {\"type\": \"text\"},\n",
    "                    \"confidentiality_level\": {\"type\": \"keyword\"},\n",
    "                }\n",
    "            },\n",
    "        },\n",
    "    )\n",
    "\n",
    "\n",
    "# Populate sample data\n",
    "def populate_data():\n",
    "    # Public HR information\n",
    "    public_docs = [\n",
    "        {\"title\": \"Annual leave policies updated.\", \"confidentiality_level\": \"low\"},\n",
    "        {\"title\": \"Remote work guidelines available.\", \"confidentiality_level\": \"low\"},\n",
    "        {\n",
    "            \"title\": \"Health benefits registration period starts next month.\",\n",
    "            \"confidentiality_level\": \"low\",\n",
    "        },\n",
    "    ]\n",
    "    for doc in public_docs:\n",
    "        es.index(index=\"rbac_rag_demo-data_public\", document=doc)\n",
    "\n",
    "    # Sensitive HR information\n",
    "    sensitive_docs = [\n",
    "        {\n",
    "            \"title\": \"Executive compensation details Q2 2024.\",\n",
    "            \"confidentiality_level\": \"high\",\n",
    "        },\n",
    "        {\n",
    "            \"title\": \"Bonus payout structure for all levels.\",\n",
    "            \"confidentiality_level\": \"high\",\n",
    "        },\n",
    "        {\n",
    "            \"title\": \"Employee stock options plan details.\",\n",
    "            \"confidentiality_level\": \"high\",\n",
    "        },\n",
    "    ]\n",
    "    for doc in sensitive_docs:\n",
    "        es.index(index=\"rbac_rag_demo-data_sensitive\", document=doc)\n",
    "\n",
    "\n",
    "create_indices()\n",
    "populate_data()"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tZJyzuqjTu5A",
    "outputId": "bd977366-e0ca-4034-92f9-dc8b9c24daea"
   },
   "execution_count": 6,
   "outputs": [
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Create two users with different access levels\n"
   ],
   "metadata": {
    "id": "fQFEfAJayVOP"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Create roles\n",
    "def create_roles():\n",
    "    # Role for the engineer\n",
    "    es.security.put_role(\n",
    "        name=\"engineer_role\",\n",
    "        body={\n",
    "            \"indices\": [\n",
    "                {\"names\": [\"rbac_rag_demo-data_public\"], \"privileges\": [\"read\"]}\n",
    "            ]\n",
    "        },\n",
    "    )\n",
    "\n",
    "    # Role for the manager\n",
    "    es.security.put_role(\n",
    "        name=\"manager_role\",\n",
    "        body={\n",
    "            \"indices\": [\n",
    "                {\n",
    "                    \"names\": [\n",
    "                        \"rbac_rag_demo-data_public\",\n",
    "                        \"rbac_rag_demo-data_sensitive\",\n",
    "                    ],\n",
    "                    \"privileges\": [\"read\"],\n",
    "                }\n",
    "            ]\n",
    "        },\n",
    "    )\n",
    "\n",
    "\n",
    "# Create users with respective roles\n",
    "def create_users():\n",
    "    # User 'engineer'\n",
    "    es.security.put_user(\n",
    "        username=\"engineer\",\n",
    "        body={\n",
    "            \"password\": \"password123\",\n",
    "            \"roles\": [\"engineer_role\"],\n",
    "            \"full_name\": \"Engineer User\",\n",
    "        },\n",
    "    )\n",
    "\n",
    "    # User 'manager'\n",
    "    es.security.put_user(\n",
    "        username=\"manager\",\n",
    "        body={\n",
    "            \"password\": \"password123\",\n",
    "            \"roles\": [\"manager_role\"],\n",
    "            \"full_name\": \"Manager User\",\n",
    "        },\n",
    "    )\n",
    "\n",
    "\n",
    "create_roles()\n",
    "create_users()"
   ],
   "metadata": {
    "id": "dBSwvhG5xcqz"
   },
   "execution_count": 7,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Test how security roles affect ability to query data"
   ],
   "metadata": {
    "id": "UJyCbx2kyhb6"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Create helper functions"
   ],
   "metadata": {
    "id": "gNg0HvKSVsfw"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Helper functions to open a connection as a given user, run a query against an index pattern, and format the results as color-coded HTML."
   ],
   "metadata": {
    "id": "g1vBpPRbE6A-"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "def get_es_connection(cid, username, password):\n",
    "    return Elasticsearch(cloud_id=cid, basic_auth=(username, password))\n",
    "\n",
    "\n",
    "def query_index(es, index_name, username):\n",
    "    try:\n",
    "        response = es.search(index=index_name, body={\"query\": {\"match_all\": {}}})\n",
    "\n",
    "        # Prepare the message\n",
    "        results_message = f'Results from querying as <span style=\"color: orange;\">{username}:</span><br>'\n",
    "        for hit in response[\"hits\"][\"hits\"]:\n",
    "            confidentiality_level = hit[\"_source\"].get(\"confidentiality_level\", \"N/A\")\n",
    "            # Use a separate name so the index_name argument is not overwritten\n",
    "            hit_index = hit.get(\"_index\", \"N/A\")\n",
    "            title = hit[\"_source\"].get(\"title\", \"No title\")\n",
    "\n",
    "            # Set color based on confidentiality level\n",
    "            if confidentiality_level == \"low\":\n",
    "                conf_color = \"lightgreen\"\n",
    "            elif confidentiality_level == \"high\":\n",
    "                conf_color = \"red\"\n",
    "            else:\n",
    "                conf_color = \"black\"\n",
    "\n",
    "            # Set color based on the index the hit came from\n",
    "            if hit_index == \"rbac_rag_demo-data_public\":\n",
    "                index_color = \"lightgreen\"\n",
    "            elif hit_index == \"rbac_rag_demo-data_sensitive\":\n",
    "                index_color = \"red\"\n",
    "            else:\n",
    "                index_color = \"black\"  # Default color\n",
    "\n",
    "            results_message += (\n",
    "                f'Index: <span style=\"color: {index_color};\">{hit_index}</span>\\t '\n",
    "                f'confidentiality level: <span style=\"color: {conf_color};\">{confidentiality_level}</span> '\n",
    "                f'title: <span style=\"color: lightblue;\">{title}</span><br>'\n",
    "            )\n",
    "\n",
    "        display(HTML(results_message))\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"Error accessing {index_name}: {str(e)}\")"
   ],
   "metadata": {
    "id": "5HBlLrrkyh6I"
   },
   "execution_count": 14,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Simulate querying as an \"engineer\" and as a \"manager\""
   ],
   "metadata": {
    "id": "XromvWf8VoWZ"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "index_pattern = \"rbac_rag_demo-data*\"\n",
    "print(\n",
    "    f\"Each user will log in with their credentials and query the same index pattern: {index_pattern}\\n\\n\"\n",
    ")\n",
    "\n",
    "for user in [\"engineer\", \"manager\"]:\n",
    "    print(f\"Logged in as {user}:\")\n",
    "\n",
    "    es_conn = get_es_connection(cloud_id, user, \"password123\")\n",
    "    # query_index displays its results and returns nothing\n",
    "    query_index(es_conn, index_pattern, user)\n",
    "    print(\"\\n\\n\")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 404
    },
    "id": "iL5udAo0Vnyr",
    "outputId": "07a96f79-96b3-4b82-ecb4-8979e08290c1"
   },
   "execution_count": 15,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Each user will log in with their credentials and query the same index pattern: rbac_rag_demo-data*\n",
      "\n",
      "\n",
      "Logged in as engineer:\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ],
      "text/html": [
       "Results from querying as <span style=\"color: orange;\">engineer:</span><br>Index: <span style=\"color: lightgreen;\">rbac_rag_demo-data_public</span>\t confidentiality level: <span style=\"color: lightgreen;\">low</span> title: <span style=\"color: lightblue;\">Annual leave policies updated.</span><br>Index: <span style=\"color: lightgreen;\">rbac_rag_demo-data_public</span>\t confidentiality level: <span style=\"color: lightgreen;\">low</span> title: <span style=\"color: lightblue;\">Remote work guidelines available.</span><br>Index: <span style=\"color: lightgreen;\">rbac_rag_demo-data_public</span>\t confidentiality level: <span style=\"color: lightgreen;\">low</span> title: <span style=\"color: lightblue;\">Health benefits registration period starts next month.</span><br>"
      ]
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n",
      "\n",
      "\n",
      "Logged in as manager:\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ],
      "text/html": [
       "Results from querying as <span style=\"color: orange;\">manager:</span><br>Index: <span style=\"color: lightgreen;\">rbac_rag_demo-data_public</span>\t confidentiality level: <span style=\"color: lightgreen;\">low</span> title: <span style=\"color: lightblue;\">Annual leave policies updated.</span><br>Index: <span style=\"color: lightgreen;\">rbac_rag_demo-data_public</span>\t confidentiality level: <span style=\"color: lightgreen;\">low</span> title: <span style=\"color: lightblue;\">Remote work guidelines available.</span><br>Index: <span style=\"color: lightgreen;\">rbac_rag_demo-data_public</span>\t confidentiality level: <span style=\"color: lightgreen;\">low</span> title: <span style=\"color: lightblue;\">Health benefits registration period starts next month.</span><br>Index: <span style=\"color: red;\">rbac_rag_demo-data_sensitive</span>\t confidentiality level: <span style=\"color: red;\">high</span> title: <span style=\"color: lightblue;\">Executive compensation details Q2 2024.</span><br>Index: <span style=\"color: red;\">rbac_rag_demo-data_sensitive</span>\t confidentiality level: <span style=\"color: red;\">high</span> title: <span style=\"color: lightblue;\">Bonus payout structure for all levels.</span><br>Index: <span style=\"color: red;\">rbac_rag_demo-data_sensitive</span>\t confidentiality level: <span style=\"color: red;\">high</span> title: <span style=\"color: lightblue;\">Employee stock options plan details.</span><br>"
      ]
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n",
      "\n",
      "\n"
     ]
    }
   ]
  }
 ]
}
